{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tool calling with mistral.rs\n",
    "\n",
    "Tool calling is a technique to enhance generation by providing the model with functions (tools) which it may call."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the tools and their schemas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from mistralrs import Runner, ToolChoice, Which, ChatCompletionRequest, Architecture\n",
    "\n",
    "tools = [\n",
    "    json.dumps(\n",
    "        {\n",
    "            \"type\": \"function\",\n",
    "            \"function\": {\n",
    "                \"name\": \"add_2_numbers\",\n",
    "                \"description\": \"Add two numbers, floating point or integer\",\n",
    "                \"parameters\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"properties\": {\n",
    "                        \"x\": {\n",
    "                            \"type\": \"number\",\n",
    "                            \"description\": \"The first number.\",\n",
    "                        },\n",
    "                        \"y\": {\n",
    "                            \"type\": \"number\",\n",
    "                            \"description\": \"The second number.\",\n",
    "                        },\n",
    "                    },\n",
    "                    \"required\": [\"x\", \"y\"],\n",
    "                },\n",
    "            },\n",
    "        }\n",
    "    )\n",
    "]\n",
    "\n",
    "def add_2_numbers(x, y):\n",
    "    return x + y\n",
    "\n",
    "\n",
    "functions = {\n",
    "    \"add_2_numbers\": add_2_numbers,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the Runner and the initial messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Please add 1234513543214 and 1111998778.\",\n",
    "    }\n",
    "]\n",
    "\n",
    "runner = Runner(\n",
    "    which=Which.Plain(\n",
    "        model_id=\"lamm-mit/Bioinspired-Llama-3-1-8B-128k-gamma\",\n",
    "        arch=Architecture.Llama,\n",
    "    ),\n",
    "    in_situ_quant=\"Q4K\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ask the model and give it access with the tools\n",
    "\n",
    "The model will return the chosen tool, if it wants to call it. We just extract the first tool because this is a demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = runner.send_chat_completion_request(\n",
    "    ChatCompletionRequest(\n",
    "        model=\"llama-3.1\",\n",
    "        messages=messages,\n",
    "        max_tokens=256,\n",
    "        presence_penalty=1.0,\n",
    "        top_p=0.1,\n",
    "        temperature=0.1,\n",
    "        tool_schemas=tools,\n",
    "        tool_choice=ToolChoice.Auto,\n",
    "    )\n",
    ")\n",
    "\n",
    "tool_called = res.choices[0].message.tool_calls[0].function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the tool, give the model what it said and the output of the tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if tool_called.name in functions:\n",
    "    args = json.loads(tool_called.arguments)\n",
    "    result = functions[tool_called.name](**args)\n",
    "    print(f\"Called tool `{tool_called.name}`\")\n",
    "\n",
    "    messages.append(\n",
    "        {\n",
    "            \"role\": \"assistant\",\n",
    "            \"content\": json.dumps({\"name\": tool_called.name, \"parameters\": args}),\n",
    "        }\n",
    "    )\n",
    "\n",
    "    messages.append({\"role\": \"tool\", \"content\": result})\n",
    "\n",
    "    res = runner.send_chat_completion_request(\n",
    "        ChatCompletionRequest(\n",
    "            model=\"llama-3.1\",\n",
    "            messages=messages,\n",
    "            max_tokens=256,\n",
    "            presence_penalty=1.0,\n",
    "            top_p=0.1,\n",
    "            temperature=0.1,\n",
    "            tool_schemas=tools,\n",
    "            tool_choice=ToolChoice.Auto,\n",
    "        )\n",
    "    )\n",
    "    # print(completion.usage)\n",
    "    print(res.choices[0].message.content)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
